package cn.zifangsky.hashtable;

import java.util.Set;
import java.util.function.BiConsumer;

/**
 * 基于“分离链接法”实现的 Hash Table
 *
 * @author zifangsky
 * @date 2018/12/11
 * @since 1.0.0
 */
public class SeparateChainHashTable<K, V> implements Map<K, V>{
    /**
     * 默认的平衡因子
     */
    static final float DEFAULT_LOAD_FACTOR = 0.75f;
    /**
     * 默认的表大小
     */
    static final int DEFAULT_INITIAL_CAPACITY = 16;
    /**
     * 最大容量
     */
    static final int MAXIMUM_CAPACITY = 1 << 30;

    /**
     * 单个节点信息
     */
    static class Node<K, V> {
        /**
         * key的hash
         */
        final int hash;
        /**
         * key
         */
        final K key;
        /**
         * value
         */
        V value;
        /**
         * 下一个节点
         */
        Node<K,V> next;

        public Node(int hash, K key, V value) {
            this.hash = hash;
            this.key = key;
            this.value = value;
            this.next = null;
        }

        public Node(int hash, K key, V value, Node<K,V> next) {
            this.hash = hash;
            this.key = key;
            this.value = value;
            this.next = next;
        }

        @Override
        public final String toString() {
            return key + "=" + value;
        }

        @Override
        public final boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Node<?, ?> node = (Node<?, ?>) o;
            return SeparateChainHashTable.equals(key, node.key) &&
                    SeparateChainHashTable.equals(value, node.value);
        }

        /**
         * 取 key 和 value 的 hashCode 的异或
         */
        @Override
        public final int hashCode() {
            return SeparateChainHashTable.hashCode(key) ^ SeparateChainHashTable.hashCode(value);
        }
    }

    /* ---------------- Fields -------------- */

    /**
     * 平衡因子
     */
    private final float loadFactor;

    /**
     * 需要再次扩容的数据量
     */
    private int threshold;
    /**
     * 当前 Hash Table 存储的键值对的数量
     */
    private int size;

    /**
     * 存储数据的数组
     */
    private Node<K,V>[] table;

    /**
     * 用于遍历所有元素
     */
    private Set<Node<K,V>> entrySet;

    public SeparateChainHashTable(){
        this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR);
    }

    public SeparateChainHashTable(int initialCapacity){
        this(initialCapacity, DEFAULT_LOAD_FACTOR);
    }

    public SeparateChainHashTable(int initialCapacity, float loadFactor) {
        if (initialCapacity < 0) {
            throw new IllegalArgumentException("Illegal initial capacity: " +
                    initialCapacity);
        }
        if (loadFactor <= 0 || Float.isNaN(loadFactor)) {
            throw new IllegalArgumentException("Illegal load factor: " +
                    loadFactor);
        }

        this.loadFactor = loadFactor;
        //初始化threshold（为了后面初始化数组，这里先临时存储，不考虑平衡因子）
        this.threshold = this.tableSizeFor(initialCapacity);
    }

    public static int hashCode(Object o) {
        return o != null ? o.hashCode() : 0;
    }

    public static boolean equals(Object a, Object b) {
        return (a == b) || (a != null && a.equals(b));
    }

    /**
     * 返回当前 Hash Table 存储的键值对的数量
     */
    @Override
    public int size() {
        return size;
    }

    /**
     * 返回当前 Hash Table 是否为空
     */
    @Override
    public boolean isEmpty() {
        return size == 0;
    }

    /**
     * 返回当前 Hash Table 是否包含某个KEY
     * @author zifangsky
     * @date 2018/12/11 19:49
     * @since 1.0.0
     * @param key KEY
     * @return V
     */
    @Override
    public boolean containsKey(K key){
        return this.getNode(key) != null;
    }

    /**
     * 根据KEY查找VALUE
     * @author zifangsky
     * @date 2018/12/11 19:49
     * @since 1.0.0
     * @param key KEY
     * @return V
     */
    @Override
    public V get(K key){
        Node<K,V> temp;
        return (temp = this.getNode(key)) != null ? temp.value : null;

    }

    /**
     * 存储某个键值对
     * @author zifangsky
     * @date 2018/12/12 14:49
     * @since 1.0.0
     * @param key KEY
     * @param value VALUE
     */
    @Override
    public void put(K key, V value){
        if (key == null) {
            throw new NullPointerException("key不能为空");
        }

        this.putVal(key, value);
    }

    /**
     * 移除某个键值对
     * @author zifangsky
     * @date 2018/12/12 14:49
     * @since 1.0.0
     * @param key KEY
     */
    @Override
    public void remove(K key){
        this.removeVal(key);
    }

    /**
     * 清空Hash Table
     * @author zifangsky
     * @date 2018/12/12 14:50
     * @since 1.0.0
     */
    @Override
    public void clear(){
        Node<K,V>[] array = table;

        if(array != null && size > 0){
            size = 0;
            for(int i = 0; i< array.length; i++){
                array[i] = null;
            }
        }
    }

    /**
     * 遍历Hash Table
     * @author zifangsky
     * @date 2018/12/12 14:55
     * @since 1.0.0
     * @param action action
     */
    @Override
    public void forEach(BiConsumer<? super K, ? super V> action) {
        if (action == null) {
            throw new NullPointerException("action不能为空");
        }

        Node<K,V>[] array = table;

        if(array != null){
            for(int i = 0; i< array.length; i++){
                for(Node<K,V> tempNode = array[i]; tempNode != null; tempNode = tempNode.next){
                    action.accept(tempNode.key, tempNode.value);
                }
            }
        }
    }

    /**
     * 计算 key 的hashCode
     * <p>算法思路是将key的高16位和低16位进行异或</p>
     * <p>这么做的目的是：后面使用 <b>(n - 1) & hash</b> 计算 key 应该在数组的哪个位置，可以发现这个操作总是在比较 n-1 的hash与 key 的hash的低位。
     * 如果在进行这个操作之前将key的高16位和低16位进行异或，那么冲突将会大大减少（并且也没有新增太多其他代价）</p>
     */
    private int hash(K key){
        int temp;
        return (key == null) ? 0 : (temp = key.hashCode()) ^ (temp >>> 16);
    }

    /**
     * 用于计算某个 hash 应该对应于数组的哪个位置
     */
    private int indexFor(int hash, int length){
        return (length - 1) & hash;
    }

    /**
     * 根据给定的容量，返回比之稍大的2的倍数的容量
     */
    private int tableSizeFor(int capacity){
        int n = capacity - 1;
        if(n < 0){
            return 1;
        }else{
            n = n | (n >>> 1);
            n = n | (n >>> 2);
            n = n | (n >>> 4);
            n = n | (n >>> 8);
            n = n | (n >>> 16);

            return n >= MAXIMUM_CAPACITY ? MAXIMUM_CAPACITY : n + 1;
        }
    }

    /**
     * 根据 key 查找对应节点
     */
    private Node<K,V> getNode(K key){
        if(table == null || table.length == 0){
            //扩容
            resize();
        }

        //计算 key 的hash
        int hash = this.hash(key);
        //计算 key 在数组的位置
        int point = this.indexFor(hash, table.length);

        Node<K,V>[] array = table;
        //每个链表的首节点
        Node<K,V> first;
        //临时节点
        Node<K,V> tempNode;
        //临时key
        K tempKey;

        if(array != null && array.length > 0 && (first = array[point]) != null){
            tempNode = first;

            //遍历链表，查找 key 所在节点
            while(tempNode != null){
                if(tempNode.hash == hash && ((tempKey = tempNode.key) == key || key.equals(tempKey))){
                    return tempNode;
                }
                tempNode = tempNode.next;
            }
        }

        return null;
    }

    /**
     * 存储某个键值对
     */
    private void putVal(K key, V value){
        Node<K,V>[] array;
        Node<K,V> first;

        if((array = table) == null || array.length == 0){
            //扩容
            array = resize();
        }

        //key的hash
        int hash = this.hash(key);
        //计算某个key应该对应于数组的哪个位置
        int index = this.indexFor(hash, array.length);

        //如果对应的数组位置为空
        if((first = array[index]) == null){
            array[index] = new Node<>(hash, key, value);
            size++;
        }
        //目标键值对在首节点
        else if(first.hash == hash && (first.key == key || key.equals(first.key))){
            first.value = value;
        }
        //继续检查链表上的其他节点
        else{
            //用于标识目标键值对是否已经存在
            boolean exist = false;
            Node<K,V> tempNode = first;
            Node<K,V> nextNode = tempNode.next;

            while (nextNode != null){
                K tempKey = nextNode.key;
                //已经存在目标键值对，则替换掉原节点值
                if(nextNode.hash == hash && (tempKey == key || key.equals(tempKey))){
                    nextNode.value = value;
                    exist = true;
                }
                tempNode = tempNode.next;
                nextNode = nextNode.next;
            }

            //链表不存在目标键值对，则在链表末尾建立新节点
            if(!exist){
                tempNode.next = new Node<>(hash, key, value);
                size++;
            }
        }

        //如果当前数据太多，则需要扩容
        if(size > threshold){
            resize();
        }
    }

    /**
     * 移除某个键值对
     */
    private void removeVal(K key){
        Node<K,V>[] array;
        Node<K,V> first;

        if((array = table) != null && array.length > 0){
            //key的hash
            int hash = this.hash(key);
            //计算某个key应该对应于数组的哪个位置
            int index = this.indexFor(hash, array.length);

            if((first = array[index]) != null){
                //如果首节点就是目标节点
                if(first.hash == hash && (first.key == key || key.equals(first.key))){
                    array[index] = first.next;
                    size--;
                }
                //否则就在链表上查找
                else{
                    Node<K,V> tempNode = first;
                    while (tempNode.next != null){
                        K tempKey = tempNode.next.key;
                        //已经存在目标键值对，则替换掉原节点值
                        if(tempNode.next.hash == hash && (tempKey == key || key.equals(tempKey))){
                            //将目标节点移除
                            tempNode.next = tempNode.next.next;
                            size--;
                            break;
                        }
                        tempNode = tempNode.next;
                    }
                }
            }
        }
    }

    /**
     * 扩容并将旧Hash Table的数据转移到新的Hash Table
     */
    private Node<K,V>[] resize(){
        //旧Table
        Node<K,V>[] oldTable = table;
        //旧的需要扩容的容量
        int oldThreshold = threshold;
        //旧的数组长度标识
        int oldLength = (oldTable == null) ? 0 : oldTable.length;
        //新的数组长度标识
        int newLength;
        //新的需要扩容的容量
        int newThreshold = 0;

        //说明本次是初始化Hash Table
        if(oldLength == 0){
            newLength = oldThreshold;

            float temp = newLength * loadFactor;
            newThreshold = (temp < (float)MAXIMUM_CAPACITY ? (int)temp : Integer.MAX_VALUE);
        }
        //说明本次不是初始化Hash Table，直接扩容
        else{
            //已经达到最大，则不再扩容
            if(oldLength >= MAXIMUM_CAPACITY){
                threshold = Integer.MAX_VALUE;
                return oldTable;
            }
            if((newLength = oldLength << 1) < MAXIMUM_CAPACITY){
                //扩大两倍
                newThreshold = oldThreshold << 1;
            }else{
                //扩大到最大
                newLength = MAXIMUM_CAPACITY;
                newThreshold = Integer.MAX_VALUE;
            }
        }

        threshold = newThreshold;
        Node<K,V>[] newTable = new Node[newLength];
        table = newTable;

        //如果原数组存在数据，则需要将原数组里面的数据挪到新数组
        if(oldTable != null){
//            this.oldTransfer(oldTable, newTable);
            this.transfer(oldTable, newTable);
        }

        return newTable;
    }

    /**
     * 转移旧节点（一般思路）
     */
    private void oldTransfer(Node<K,V>[] oldTable, Node<K,V>[] newTable){
        for(int i = 0; i < oldTable.length; i++){
            Node<K,V> first;
            if((first = oldTable[i]) != null){
                //释放旧数组的对象引用
                oldTable[i] = null;
                //临时节点
                Node<K,V> nextNode;

                //处理首节点
                nextNode = first.next;
                int index = this.indexFor(first.hash, newTable.length);
                first.next =newTable[index];
                newTable[index] = first;

                //通过遍历链表并转移
                while(nextNode != null){
                    index = this.indexFor(nextNode.hash, newTable.length);
                    Node<K,V> tempNode = newTable[index];
                    //每次把旧链表节点放到新链表的头节点
                    newTable[index] = nextNode;
                    nextNode = nextNode.next;
                    //然后把新链表原来的节点放到后面
                    newTable[index].next = tempNode;
                }
            }
        }
    }

    /**
     * 转移旧节点（思路来至{@link java.util.HashMap}）
     * <p>详细算法过程：</p>
     * <p>比较key的hash与旧数组长度的二进制的最高位，如果高位为0，则还是存在跟原索引相同的位置，否则存在（原索引+原数组长度）的位置</p>
     */
    private void transfer(Node<K,V>[] oldTable, Node<K,V>[] newTable){
        //旧的数组长度标识
        int oldLength = oldTable.length;

        for(int i = 0; i < oldTable.length; i++){
            //临时节点
            Node<K,V> tempNode = oldTable[i];

            if(tempNode != null){
                //如果只有一个节点，则直接移过去
                if(tempNode.next == null){
                    newTable[this.indexFor(tempNode.hash, newTable.length)] = tempNode;
                }else{
                    //低位头结点和尾节点
                    Node<K,V> lowHead = null, lowTail = null;
                    //高位头结点和尾节点
                    Node<K,V> highHead = null, highTail = null;

                    while (tempNode != null){
                        //如果高位是0，则将这个节点放到低位索引处
                        if((tempNode.hash & oldLength) == 0){
                            if(lowTail == null){
                                lowHead = tempNode;
                            }else{
                                lowTail.next = tempNode;
                            }
                            //此时tempNode变成新的低位尾节点
                            lowTail = tempNode;
                        }else{
                            if(highTail == null){
                                highHead = tempNode;
                            }else{
                                highTail.next = tempNode;
                            }
                            //此时tempNode变成新的高位尾节点
                            highTail = tempNode;
                        }

                        tempNode = tempNode.next;
                    }

                    //将两组整理好的节点链接到新数组
                    if(lowTail != null){
                        lowTail.next = null;
                        newTable[i] = lowHead;
                    }
                    if(highTail != null){
                        highTail.next = null;
                        newTable[i + oldLength] = highHead;
                    }
                }
            }
        }
    }

}
